Bootstrapping model fits

The previous section describes fitting a single model. But we may also want confidence estimates for the fit, which we can get by bootstrapping the data set.

The recommended workflow is to first fit models to all the data to determine the number of epitopes and other fitting parameters. Once these parameters are settled, you can bootstrap to get confidence estimates on the predictions.
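As a generic illustration of the idea, the sketch below bootstraps a simple statistic. It is not how PolyclonalCollection resamples its data, which this section does not spell out. It resamples hypothetical measurements with replacement, recomputes the statistic on each resample, and uses the spread of those estimates as the error on the estimate from the full ("root") data.

```python
import numpy as np

# Hypothetical per-variant measurements standing in for a real data set.
rng = np.random.default_rng(seed=1)
values = rng.normal(loc=2.0, scale=0.5, size=200)


def bootstrap_std(data, statistic, n_samples=100, seed=0):
    """Standard deviation of `statistic` across bootstrap resamples of `data`."""
    rng = np.random.default_rng(seed)
    estimates = []
    for _ in range(n_samples):
        # Resample with replacement, keeping the original sample size.
        resample = rng.choice(data, size=len(data), replace=True)
        estimates.append(statistic(resample))
    return np.std(estimates, ddof=1)


root_estimate = np.mean(values)  # estimate from all the data ("root")
error = bootstrap_std(values, np.mean)
print(f"estimate = {root_estimate:.2f} +/- {error:.2f}")
```

More bootstrap samples give a more stable error estimate, which is why real analyses use more replicates than the quick example in this notebook.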

Get model fit to the data

The first step is to fit a Polyclonal model to all the data we are using. We do this as in the previous notebook for our RBD example, but first shrink the data set to just 7500 variants to add more “error” and so better illustrate the bootstrapping.
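The code cell in this section assigns each variant a "barcode" by sorting and then numbering rows within each concentration, and subsamples barcodes rather than rows. The toy sketch below (hypothetical values, not the RBD data) shows why: assuming every concentration contains the same set of variants, the same variant gets the same barcode at every concentration, so keeping a barcode keeps all of that variant's measurements.

```python
import pandas as pd

# Toy stand-in for the RBD data: 4 variants, each measured at 2 concentrations.
data = pd.DataFrame(
    {
        "aa_substitutions": ["", "K417N", "E484K", "N501Y"] * 2,
        "concentration": [0.25] * 4 + [1.0] * 4,
        "prob_escape": [0.01, 0.2, 0.6, 0.05, 0.0, 0.1, 0.4, 0.02],
    }
)

# Sort by concentration, then by substitutions, and number rows within each
# concentration: identical variants get the same number at every concentration.
data = (
    data.sort_values(["concentration", "aa_substitutions"])
    .reset_index(drop=True)
    .assign(barcode=lambda x: x.groupby("concentration").cumcount())
)

# Subsample variants (barcodes), not rows, so each kept variant retains all
# of its concentrations.
keep = data["barcode"].drop_duplicates().sample(2, random_state=1).tolist()
subset = data[data["barcode"].isin(keep)]
print(subset)
```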

We call this model, fit to all the data we are using, the “root” model, since it is the starting point (root) for the subsequent bootstrapping. Note that the data (which we will bootstrap) are attached to this pre-fit model:

[1]:
# NBVAL_IGNORE_OUTPUT

import pandas as pd

import polyclonal

# read the data, and just make "barcode" the numerical rank of the variants
noisy_data = (
    pd.read_csv("RBD_variants_escape_noisy.csv", na_filter=None)
    .query('library == "avg3muts"')
    .query("concentration in [0.25, 1, 4]")
    .sort_values(["concentration", "aa_substitutions"])
    .reset_index(drop=True)
    .assign(barcode=lambda x: x.groupby("concentration").cumcount())
)

# just keep some variants to make fitting "noisier"
n_keep = 7500
barcodes_to_keep = (
    noisy_data["barcode"].drop_duplicates().sample(n_keep, random_state=1).tolist()
)
noisy_data = noisy_data.query("barcode in @barcodes_to_keep")

# make and fit the root Polyclonal object with all the data we are using
root_poly = polyclonal.Polyclonal(
    data_to_fit=noisy_data,
    activity_wt_df=pd.DataFrame.from_records(
        [
            ("class 1", 1.0),
            ("class 2", 3.0),
            ("class 3", 2.0),
        ],
        columns=["epitope", "activity"],
    ),
    site_escape_df=pd.DataFrame.from_records(
        [
            ("class 1", 417, 10.0),
            ("class 2", 484, 10.0),
            ("class 3", 444, 10.0),
        ],
        columns=["epitope", "site", "escape"],
    ),
    data_mut_escape_overlap="fill_to_data",
)

opt_res = root_poly.fit(logfreq=100)
# First fitting site-level model.
# Starting optimization of 522 parameters at Sat Mar 19 09:34:43 2022.
       step   time_sec       loss   fit_loss reg_escape  regspread
          0   0.022838       4506     4505.7    0.29701          0
        100     2.5614     550.09     546.34     3.7432          0
        200     4.9081     541.57     537.02     4.5554          0
        300      7.272        539     533.84     5.1659          0
        400     9.5925     538.27      532.9     5.3674          0
        500     11.941     537.67     532.23     5.4371          0
        600     14.214     537.13     531.63     5.5078          0
        700     16.518     536.64     530.85     5.7896          0
        800     18.892     536.31      530.4     5.9137          0
        900      21.26     536.06     530.11      5.956          0
       1000     23.576     535.51      529.4     6.1068          0
       1100     25.904     535.05      528.9     6.1524          0
       1200     28.217     534.73     528.47     6.2531          0
       1300     30.477     534.43     528.04     6.3902          0
       1400     32.749     534.36     527.89     6.4741          0
       1500     35.087     534.33     527.77     6.5579          0
       1558     36.507     534.32     527.71     6.6146          0
# Successfully finished at Sat Mar 19 09:35:20 2022.
# Starting optimization of 5799 parameters at Sat Mar 19 09:35:20 2022.
       step   time_sec       loss   fit_loss reg_escape  regspread
          0   0.027019     643.27      566.6     76.667  1.589e-29
        100     2.8415     323.66     237.57     75.125      10.97
        200      5.609     310.33     223.52     70.913       15.9
        300      8.423     300.99     220.06      64.13       16.8
        400     11.172     286.66     215.39     52.955      18.31
        500     13.872     276.87     209.89     47.828     19.156
        600     16.618     271.09     205.02     45.623     20.445
        700     19.431     265.34     199.01     44.579     21.751
        800     22.216     260.79     192.97     44.793     23.031
        900     24.935     258.05     188.63     45.423     24.004
       1000     27.775     255.18      183.5     46.481     25.201
       1100     30.532     252.88      179.9     47.376     25.609
       1200     33.198     251.35     177.86     47.671     25.813
       1300      35.91     249.55     174.99      48.19     26.365
       1400     38.565     247.69     172.67      48.51     26.513
       1500     41.315     246.33     171.11     48.751     26.472
       1600     43.945     245.85     170.46      48.93     26.455
       1700     46.531     245.54     169.81     49.079     26.647
       1800     49.215     245.09     169.05     49.252     26.788
       1900     52.001      244.8      168.4     49.421     26.982
       2000     54.759     244.57     168.41     49.371     26.794
       2100      57.55      244.4     168.07     49.405     26.926
       2200     60.373     244.25     167.91     49.429     26.915
       2300     63.072     244.16     167.75      49.47     26.941
       2400     65.807     244.06     167.59     49.501     26.979
       2500     68.558     243.94     167.38      49.51     27.046
       2600     71.229     243.83     167.26     49.516     27.049
       2700     74.031     243.75     167.16     49.522      27.07
       2800     76.775      243.7     167.13      49.52     27.051
       2900     79.589     243.67     167.04      49.52     27.112
       3000     82.234     243.61     167.06     49.506     27.044
       3100     84.976     243.57     167.08     49.505      26.99
       3200     87.754     243.53     167.02     49.545     26.973
       3300     90.538      243.5     167.04     49.555     26.906
       3400     93.249     243.48     167.01     49.575     26.897
       3500     95.987     243.45     166.99     49.587      26.87
       3596     98.538     243.42     166.97     49.586     26.866
# Successfully finished at Sat Mar 19 09:36:59 2022.
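In the logs above, the loss column appears to equal the sum of the fit_loss, reg_escape and regspread columns, up to rounding of the printed values. This quick check uses rows copied from the second optimization; whether other penalty terms could contribute under other settings is not shown in this section.

```python
# Rows copied from the second optimization log above:
# (step, loss, fit_loss, reg_escape, regspread)
log_rows = [
    (0, 643.27, 566.6, 76.667, 1.589e-29),
    (100, 323.66, 237.57, 75.125, 10.97),
    (1000, 255.18, 183.5, 46.481, 25.201),
    (3596, 243.42, 166.97, 49.586, 26.866),
]

for step, loss, fit_loss, reg_escape, regspread in log_rows:
    total = fit_loss + reg_escape + regspread
    # Allow for rounding of the printed values (about 5 significant figures).
    assert abs(total - loss) < 0.05, (step, total, loss)
    print(f"step {step}: summed terms = {total:.2f}, logged loss = {loss}")
```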

Create and fit bootstrapped models

To create the bootstrapped models, we initialize a PolyclonalCollection, here using just 5 bootstrap samples for speed (for real analyses, good error estimates may require more, on the order of 20 to 100 samples). Note that the root model must already have been fit to the data! There is also an n_threads option that specifies how many threads to use for the bootstrapping: by default it is -1 (use all available CPUs), but you can set it to another number to limit CPU usage:

[2]:
n_bootstrap_samples = 5

bootstrap_poly = polyclonal.PolyclonalCollection(
    root_polyclonal=root_poly,
    n_bootstrap_samples=n_bootstrap_samples,
)

Now fit the bootstrapped models:

[3]:
# NBVAL_IGNORE_OUTPUT

import time

start = time.time()
print(f"Starting fitting bootstrap models at {time.asctime()}")
n_fit, n_failed = bootstrap_poly.fit_models()
print(f"Fitting took {time.time() - start:.3g} seconds, finished at {time.asctime()}")
assert n_failed == 0 and n_fit == n_bootstrap_samples
Starting fitting bootstrap models at Sat Mar 19 09:37:00 2022
Fitting took 65.2 seconds, finished at Sat Mar 19 09:38:06 2022
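Because each bootstrap replicate is fit independently, the fits can run in parallel. The sketch below is hypothetical (fit_one and fit_all are not part of polyclonal) and only illustrates the pattern: map the fits over a worker pool whose size follows the n_threads convention described above, then count successes and failures, mirroring the (n_fit, n_failed) pair returned by fit_models.

```python
import os
from concurrent.futures import ThreadPoolExecutor


def fit_one(sample_id):
    """Hypothetical stand-in for fitting one bootstrap model; True on success."""
    # A real fit would resample the data and optimize a model here.
    return sample_id >= 0


def fit_all(sample_ids, n_threads=-1):
    """Fit independent replicates in parallel; return (n_fit, n_failed)."""
    # Follow the convention described above: -1 means use all available CPUs.
    max_workers = os.cpu_count() if n_threads == -1 else n_threads
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        results = list(pool.map(fit_one, sample_ids))
    n_fit = sum(results)
    return n_fit, len(results) - n_fit


n_fit, n_failed = fit_all(range(5), n_threads=2)
print(n_fit, n_failed)
```

For CPU-bound fitting in pure Python a process pool would usually be the better choice; a thread pool just keeps the sketch simple.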

Look at summarized results

We can get the resulting estimates of the epitope activities and mutation escape values, both per replicate and summarized across replicates (mean, median, standard deviation).

Epitope activities

Epitope activities for each replicate:

[4]:
# NBVAL_IGNORE_OUTPUT
bootstrap_poly.activity_wt_df_replicates.round(1)
[4]:
    epitope  activity  bootstrap_replicate
0   class 1       2.0                    1
1   class 2       2.6                    1
2   class 3       2.1                    1
3   class 1       1.9                    2
4   class 2       2.7                    2
5   class 3       1.9                    2
6   class 1       2.1                    3
7   class 2       2.5                    3
8   class 3       2.0                    3
9   class 1       1.9                    4
10  class 2       2.7                    4
11  class 3       2.0                    4
12  class 1       2.1                    5
13  class 2       2.6                    5
14  class 3       1.9                    5

Epitope activities summarized across replicates. The std column gives the standard deviation:

[5]:
# NBVAL_IGNORE_OUTPUT
bootstrap_poly.activity_wt_df.round(1)
[5]:
   epitope  mean  median  std
0  class 1   2.0     2.0  0.1
1  class 2   2.6     2.6  0.1
2  class 3   2.0     2.0  0.1
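As a check on what this summary contains, the sketch below recomputes it with pandas from the rounded per-replicate activities printed above, and gives the same rounded mean, median and std as the table. Whether polyclonal uses the sample standard deviation (pandas' default, ddof=1) or the population one is an assumption here; at one decimal place both give 0.1 for these values.

```python
import pandas as pd

# The per-replicate epitope activities printed above (already rounded to 0.1).
replicates = pd.DataFrame(
    {
        "epitope": ["class 1", "class 2", "class 3"] * 5,
        "activity": [2.0, 2.6, 2.1, 1.9, 2.7, 1.9, 2.1, 2.5, 2.0,
                     1.9, 2.7, 2.0, 2.1, 2.6, 1.9],
        "bootstrap_replicate": [r for r in range(1, 6) for _ in range(3)],
    }
)

# Summarize each epitope across replicates (pandas' std uses ddof=1).
summary = (
    replicates.groupby("epitope")["activity"]
    .agg(["mean", "median", "std"])
    .round(1)
    .reset_index()
)
print(summary)
```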

We can plot the epitope activities summarized across replicates. The dropdown allows you to choose the summary stat (mean, median), and the black lines indicate the standard deviation. Mouse over for values:

[6]:
# NBVAL_IGNORE_OUTPUT
bootstrap_poly.activity_wt_barplot()

Mutation escape values

Mutation escape values for each replicate:

[7]:
# NBVAL_IGNORE_OUTPUT
bootstrap_poly.mut_escape_df_replicates.round(1).head()
[7]:
   epitope  site  wildtype  mutant  mutation  escape  bootstrap_replicate
0  class 1   331         N       A     N331A     0.4                    1
1  class 1   331         N       D     N331D    -0.4                    1
2  class 1   331         N       E     N331E     0.3                    1
3  class 1   331         N       F     N331F     0.1                    1
4  class 1   331         N       G     N331G     0.2                    1

Mutation escape values summarized across replicates. Note that the frac_bootstrap_replicates column gives the fraction of bootstrap replicates with a value for this mutation:

[8]:
# NBVAL_IGNORE_OUTPUT
bootstrap_poly.mut_escape_df.round(1).head(n=3)
[8]:
   epitope  site  wildtype  mutant  mutation  mean  median  std  n_bootstrap_replicates  frac_bootstrap_replicates
0  class 1   331         N       A     N331A   0.2     0.2  0.4                       5                        1.0
1  class 1   331         N       D     N331D  -0.2    -0.1  0.2                       5                        1.0
2  class 1   331         N       E     N331E  -0.0     0.0  0.3                       5                        1.0
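The frac_bootstrap_replicates column can be read as n_bootstrap_replicates divided by the number of bootstrap samples (5 here, so a value in all 5 replicates gives 1.0). A mutation that is rare in the data may be missing from some bootstrap resamples and so have values in fewer replicates. The sketch below uses hypothetical values (not the RBD results) to compute these columns and apply a 0.9 threshold like the one used for the heatmap; it illustrates the idea and is not polyclonal's code.

```python
import pandas as pd

n_bootstrap_samples = 5

# Hypothetical per-replicate escape values: a common mutation with a value in
# every replicate, and a rare mutation absent from some bootstrap resamples.
reps = pd.DataFrame(
    {
        "mutation": ["common_mut"] * 5 + ["rare_mut"] * 2,
        "escape": [0.4, 0.1, 0.3, -0.2, 0.4, 1.5, -0.3],
        "bootstrap_replicate": [1, 2, 3, 4, 5, 1, 4],
    }
)

# Summarize each mutation across the replicates in which it has a value.
summary = (
    reps.groupby("mutation")["escape"]
    .agg(["mean", "median", "std", "count"])
    .rename(columns={"count": "n_bootstrap_replicates"})
    .reset_index()
)
summary["frac_bootstrap_replicates"] = (
    summary["n_bootstrap_replicates"] / n_bootstrap_samples
)

# Keep only mutations with values in at least 90% of bootstrap replicates.
min_frac_bootstrap_replicates = 0.9
filtered = summary[
    summary["frac_bootstrap_replicates"] >= min_frac_bootstrap_replicates
]
print(filtered)
```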

We can plot the mutation escape values across replicates. The dropdown selects the statistic shown in the heatmap (mean or median), and mouseovers give details on points. Here we set min_frac_bootstrap_replicates=0.9 to report only escape values observed in at least 90% of bootstrap replicates (this removes rare mutations):

[9]:
# NBVAL_IGNORE_OUTPUT
bootstrap_poly.mut_escape_heatmap(min_frac_bootstrap_replicates=0.9)

Site summaries of mutation escape

Site summaries of mutation escape values for replicates:

[10]:
# NBVAL_IGNORE_OUTPUT
bootstrap_poly.mut_escape_site_summary_df_replicates.round(1).head()
[10]:
   epitope  site  wildtype  mean  total positive  max   min  total negative  bootstrap_replicate
0  class 1   331         N   0.5             8.9  1.8  -0.7            -1.3                    1
1  class 1   332         I   0.6            10.6  1.5   0.0             0.0                    1
2  class 1   333         T   0.5             9.4  1.3  -0.7            -0.9                    1
3  class 1   334         N   0.8            13.9  1.9  -0.2            -0.3                    1
4  class 1   335         L   0.5             9.5  1.5  -0.5            -0.8                    1
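The site-summary metrics are not defined in this section. A reasonable reading of their names, assumed in the sketch below, is that mean, max and min summarize the escape values of the mutations at a site, while total positive and total negative separately sum the positive and the negative values. The escape values are hypothetical.

```python
import pandas as pd

# Hypothetical escape values for the mutations at two sites of one epitope.
mut_escape = pd.DataFrame(
    {
        "site": [331, 331, 331, 332, 332],
        "escape": [1.2, -0.4, 0.6, 0.3, 0.5],
    }
)


def total_positive(escape):
    """Sum of the positive escape values at a site (assumed definition)."""
    return escape[escape > 0].sum()


def total_negative(escape):
    """Sum of the negative escape values at a site (assumed definition)."""
    return escape[escape < 0].sum()


# Function aggregators are named after the function, then renamed to match
# the column names used in the tables above.
site_summary = (
    mut_escape.groupby("site")["escape"]
    .agg(["mean", total_positive, "max", "min", total_negative])
    .rename(columns={"total_positive": "total positive",
                     "total_negative": "total negative"})
)
print(site_summary.round(2))
```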

Site summaries of mutation escape values summarized (e.g., averaged) across replicates. Note that the metric column now indicates a different row for each site-summary metric type, which is then summarized by its mean, median, and standard deviation:

[11]:
# NBVAL_IGNORE_OUTPUT
bootstrap_poly.mut_escape_site_summary_df.round(1).head()
[11]:
   epitope  site  wildtype  metric          mean  median  std  n_bootstrap_replicates  frac_bootstrap_replicates
0  class 1   331         N  max              1.6     1.7  0.2                       5                        1.0
1  class 1   331         N  mean             0.5     0.5  0.2                       5                        1.0
2  class 1   331         N  min             -0.5    -0.5  0.2                       5                        1.0
3  class 1   331         N  total negative  -1.1    -1.3  0.7                       5                        1.0
4  class 1   331         N  total positive   8.4     8.9  2.0                       5                        1.0
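Going from the per-replicate table (one column per metric) to this summary (one row per metric) amounts to reshaping to long format and then summarizing each site and metric across replicates. A minimal pandas sketch with hypothetical values, not polyclonal's implementation:

```python
import pandas as pd

# Hypothetical per-replicate site summaries for one site, in wide format
# (one column per metric).
wide = pd.DataFrame(
    {
        "site": [331, 331, 331],
        "mean": [0.5, 0.4, 0.6],
        "max": [1.8, 1.5, 1.6],
        "bootstrap_replicate": [1, 2, 3],
    }
)

# Reshape to long format: one row per (site, replicate, metric).
long_df = wide.melt(
    id_vars=["site", "bootstrap_replicate"], var_name="metric", value_name="value"
)

# Summarize each site-metric pair across replicates.
summary = (
    long_df.groupby(["site", "metric"])["value"]
    .agg(["mean", "median", "std"])
    .reset_index()
)
print(summary.round(2))
```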

We can plot site summaries of the mutation escape. There are options to toggle the error bars (standard deviations) on and off, to select which metric is shown (e.g., mean effect of mutations at the site, total positive escape at the site, etc.), and to choose how that metric is summarized across replicates (mean or median):

[12]:
# NBVAL_IGNORE_OUTPUT
bootstrap_poly.mut_escape_lineplot(min_frac_bootstrap_replicates=0.9)

Some tests

Below are tests that the results are approximately consistent with the expected values:

[13]:
sites = [417, 446, 484, 501]  # just test these sites
for attr, atol in [
    ("activity_wt_df", 0.5),
    ("mut_escape_site_summary_df", 1.0),
    ("mut_escape_df", 1.0),
]:
    print(f"Testing {attr}")
    df = getattr(bootstrap_poly, attr).round(1).drop(columns="std")
    if "site" in df.columns:
        df = df.query("site in @sites").reset_index(drop=True)
    f = f"RBD_bootstrap_expected_{attr}.csv"
    expected = pd.read_csv(f).drop(columns="std")
    pd.testing.assert_frame_equal(
        df,
        expected,
        atol=atol,
        rtol=0.2,
        obj=f"{attr} DataFrame",
    )
Testing activity_wt_df
Testing mut_escape_site_summary_df
Testing mut_escape_df
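The tests above compare numeric columns against stored expected values with tolerances rather than exactly. This toy sketch shows the effect of the atol and rtol arguments of pandas.testing.assert_frame_equal: a small numeric deviation passes, a large one fails, and non-numeric columns must still match exactly. The frames and the frames_match helper are hypothetical.

```python
import pandas as pd

expected = pd.DataFrame({"epitope": ["class 1"], "mean": [2.0]})


def frames_match(observed, atol=0.5, rtol=0.2):
    """True if assert_frame_equal accepts `observed` within the tolerances."""
    try:
        pd.testing.assert_frame_equal(observed, expected, atol=atol, rtol=rtol)
    except AssertionError:
        return False
    return True


# Small numeric deviation, large numeric deviation, mismatched string column.
print(frames_match(pd.DataFrame({"epitope": ["class 1"], "mean": [2.4]})))
print(frames_match(pd.DataFrame({"epitope": ["class 1"], "mean": [4.0]})))
print(frames_match(pd.DataFrame({"epitope": ["class 2"], "mean": [2.0]})))
```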